package com.lock.zklock;

import com.google.common.collect.Lists;
import com.lock.ZLock;
import org.apache.commons.lang3.StringUtils;
import org.apache.curator.RetryLoop;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.imps.CuratorFrameworkState;
import org.apache.curator.framework.recipes.locks.PredicateResults;
import org.apache.curator.utils.ZKPaths;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;

import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * ZK 阻塞 重入锁
 * @author wanghaitao
 */
public class ZKReentrantLock implements ZLock {
	// private final Logger log = LoggerFactory.getLogger(this.getClass());
	public final static String UNDER_LINE = "_";
	public final static String PREFIX_MID_PATH = "zklocking2";
	public final static String BASE_PATH = "/daling_zk";

	private final CuratorFramework client;
	private String baseLockPath = ZKPaths.makePath(BASE_PATH, PREFIX_MID_PATH);;
	private final String lockName;
	private final String lockPath;
	private String selfPath;
	private final String uuidValue;

	private final ThreadLocal<MyLockData> threadData = new ThreadLocal<MyLockData>();

	public ZKReentrantLock(CuratorFramework client, String thisLockName) {
		this.client = client;
		this.lockName = thisLockName + UNDER_LINE;
		this.lockPath = ZKPaths.makePath(baseLockPath, thisLockName + UNDER_LINE);
		this.uuidValue = UUID.randomUUID().toString();
	}

	public ZKReentrantLock(CuratorFramework client, String basePath, String thisLockName) {
		this.client = client;
		this.baseLockPath = basePath;
		this.lockName = thisLockName + UNDER_LINE;
		this.lockPath = ZKPaths.makePath(basePath, thisLockName + UNDER_LINE);
		this.uuidValue = UUID.randomUUID().toString();
	}

	public String getBaseLockPath() {
		return baseLockPath;
	}

	public void setBaseLockPath(String baseLockPath) {
		this.baseLockPath = baseLockPath;
	}

	@Override
	public void lock() throws Exception {
		if (!tryLock(-1, null)) {
			throw new Exception("Lost connection while trying to acquire lock: " + baseLockPath);
		}

	}

	@Override
	public boolean tryLock() throws Exception {
		return tryLock(0L, TimeUnit.MILLISECONDS);
	}

	@Override
	public boolean tryLock(Duration waitTime) throws Exception {
		return tryLock(waitTime.toMillis(), TimeUnit.MILLISECONDS);
	}

	@Override
	public boolean tryLock(long time, TimeUnit unit) throws Exception {
		MyLockData lockData = threadData.get();
		if (lockData != null) {
			// re-entering
			lockData.lockCount.incrementAndGet();
			return true;
		} else {

			String lockPath = attemptLock(time, unit);
			if (lockPath != null) {
				threadData.set(new MyLockData(uuidValue, lockPath));
				return true;
			} else {
				return false;
			}
		}
	}

	/**
	 * 
	 * 创建临时节点, 等待获取锁
	 * 
	 * @param time
	 * @param unit
	 * @return
	 * @throws Exception
	 */
	String attemptLock(long time, TimeUnit unit) throws Exception {
		final long startMillis = System.currentTimeMillis();
		final Long millisToWait = (unit != null) ? unit.toMillis(time) : null;
		int retryCount = 0;

		boolean hasTheLock = false;
		boolean isDone = false;
		while (!isDone) {
			isDone = true;

			try {
				selfPath = client.create().creatingParentsIfNeeded()// .withProtection()
						.withMode(CreateMode.EPHEMERAL_SEQUENTIAL).forPath(lockPath);
				// log.info("{} 创建锁路径:{}", myID, selfPath);

				hasTheLock = internalLockLoop(startMillis, millisToWait);
			} catch (KeeperException.NoNodeException e) {
				if (client.getZookeeperClient().getRetryPolicy()
						.allowRetry(retryCount++, System.currentTimeMillis() - startMillis, RetryLoop.getDefaultRetrySleeper())) {
					isDone = false;
				} else {
					throw e;
				}
			}
		}

		if (hasTheLock) {
			return selfPath;
		}

		return null;
	}

	/**
	 * 
	 * 循环检查临时节点队列, 获取自己为序号最小节点后, 获得锁退出
	 * 
	 * @param startMillis
	 * @param millisToWait
	 * @return
	 * @throws Exception
	 */
	private boolean internalLockLoop(long startMillis, Long millisToWait) throws Exception {
		boolean haveTheLock = false;
		boolean doDelete = false;
		try {
			while ((client.getState() == CuratorFrameworkState.STARTED) && !haveTheLock) {
				PredicateResults checkMinPathResult = checkMinPath();
				// log.debug("{} Results:getsTheLock:{}, Results:getPathToWatch:{}",
				// myID, getLockResults.getsTheLock(),
				// getLockResults.getPathToWatch());
				if (checkMinPathResult.getsTheLock()) {
					haveTheLock = true;
				} else {
					String waitPath = checkMinPathResult.getPathToWatch();
					// log.debug("{} waitPath:{}", myID, waitPath);

					synchronized (this) {
						try {
							// 用 getData()替代exists()以避免泄露watchers资源
							client.getData().usingWatcher(watcher).forPath(waitPath);
							if (millisToWait != null) {
								millisToWait -= (System.currentTimeMillis() - startMillis);
								startMillis = System.currentTimeMillis();
								if (millisToWait <= 0) {
									doDelete = true; // 过期, 删除当前锁节点
									break;
								}
								wait(millisToWait);
							} else {
								wait();
							}
						} catch (KeeperException.NoNodeException e) {
							// it has been deleted (i.e. lock released). Try to acquire again
						}
					}
				}
			}
		} catch (Exception e) {
			doDelete = true;
			throw e;
		} finally {
			if (doDelete) {
				innerReleaseLock(selfPath);
			}
		}
		return haveTheLock;
	}

	/**
	 * 检查自己是不是最小的节点
	 * 
	 * @return
	 * @throws Exception
	 */
	public PredicateResults checkMinPath() throws Exception {
		String waitPath = null;

		List<String> children = client.getChildren().forPath(baseLockPath);
		// log.info("{} basePath={}, children={}", myID, basePath, children);
		// 要过滤出lockName开头的节点, 其他lock也在同一basePath目录下, 不需要参与排序和计算
		Lists.newArrayList(children);
		List<String> subNodes = Lists.newArrayList();
		int indx = 0;
		for (String item : children) {
			// indx = item.lastIndexOf(UNDER_LINE);
			if ((indx = item.lastIndexOf(UNDER_LINE)) > 0 && StringUtils.equals(item.substring(0, indx + 1), lockName)) {
				subNodes.add(item);
			}
		}
		//
		Collections.sort(subNodes);
		// log.info("{} basePath={}, subNodes={}", myID, basePath, subNodes);

		// log.info("{} basePath={}, subNodes={}", myID, basePath, subNodes);

		int lockIndex = subNodes.indexOf(selfPath.substring(baseLockPath.length() + 1));
		// log.info("{} subNodes.indexOf={},index={}", myID,
		// selfPath.substring(basePath.length() + 1), index);

		switch (lockIndex) {
		case -1: {
			// log.error(myID + " does not exist..." + selfPath);
			throw new KeeperException.NoNodeException("Sequential path not found: " + selfPath);
		}
		case 0: {
			// log.debug(myID + " is first node " + selfPath);
			return new PredicateResults(null, true);
		}
		default: {
			waitPath = baseLockPath + "/" + subNodes.get(lockIndex - 1);
			// log.debug(myID + " get child list, previous node is" + waitPath);
			return new PredicateResults(waitPath, false);
		}

		}

	}

	@Override
	public boolean isLocked() {
		MyLockData lockData = threadData.get();
		return (lockData != null && StringUtils.equals(lockData.trackID, uuidValue));
	}

	@Override
	public boolean isHoldByCurrentThread() {
		MyLockData lockData = threadData.get();
		return (lockData != null && StringUtils.equals(lockData.trackID, uuidValue));
	}

	@Override
	public void unlock() throws Exception {

		MyLockData lockData = threadData.get();
		if (lockData == null) {
			return;
			// throw new IllegalMonitorStateException("You do not own the lock: " + baseLockPath);
		}

		int newLockCount = lockData.lockCount.decrementAndGet();

		if (newLockCount > 0) {
			// 重入锁未完全解锁
			return;
		} else if (newLockCount < 0) {
			throw new IllegalMonitorStateException("Lock count has gone negative for lock: " + baseLockPath);
		}

		try {
			innerReleaseLock(lockData.lockPath);
		} finally {
			threadData.remove();
		}

	}

	@Override
	public void close() {
		try {
			unlock();
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	private void innerReleaseLock(String lockPath) throws Exception {
		try {
			client.delete().guaranteed().forPath(lockPath);
		} catch (KeeperException.NoNodeException e) {
			// ignore - already deleted (possibly expired session, etc.)
		}
	}

	@Override
	public int getHoldCount() {
		MyLockData lockData = threadData.get();
		return lockData == null ? 0 : lockData.lockCount.get();
	}

	/**
	 * 唤醒其他等待锁的线程
	 */
	private synchronized final void notifyFromWatcher() {
		notifyAll();
	}

	private static class MyLockData {
		final String trackID;
		final String lockPath;
		final AtomicInteger lockCount = new AtomicInteger(1);

		private MyLockData(String trackID, String lockPath) {
			this.trackID = trackID;
			this.lockPath = lockPath;
		}
	}

	private final Watcher watcher = new Watcher() {
		@Override
		public void process(WatchedEvent event) {
			notifyFromWatcher();
		}
	};

	// private PredicateResults getsTheLock(CuratorFramework client, List<String> children, String sequenceNodeName) throws Exception {
	// int ourIndex = children.indexOf(sequenceNodeName);
	// if (ourIndex < 0) {
	// throw new KeeperException.NoNodeException("Sequential path not found: " + sequenceNodeName);
	// }
	//
	// boolean getsTheLock = ourIndex < 1;
	// String pathToWatch = getsTheLock ? null : children.get(ourIndex - 1);
	//
	// return new PredicateResults(pathToWatch, getsTheLock);
	// }

}
